{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from keras.models import Model\n",
    "from ch10 import construct_seq2seq_model\n",
    "from nlpia.loaders import get_data, DATA_PATH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Corpus location and training hyperparameters.\n",
    "# The file holds preprocessed CMU movie dialogue samples.\n",
    "data_path = os.path.join(DATA_PATH, 'movie_dialog.txt')\n",
    "num_samples = 10000  # Upper bound on the number of corpus lines that get used.\n",
    "batch_size = 64      # Examples per gradient update.\n",
    "epochs = 100         # Passes over the training data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    import cPickle as pickle  # Python 2 name of the C implementation\n",
    "except ImportError:\n",
    "    import pickle\n",
    "\n",
    "from io import open\n",
    "\n",
    "# Caution: pickle.load can run arbitrary code, so only open stats files you produced yourself.\n",
    "# Character vocabularies and their character -> index lookup tables.\n",
    "with open(\"../data/characters_stats.pkl\", \"rb\") as chars_file:\n",
    "    (input_characters, target_characters,\n",
    "     input_token_index, target_token_index) = pickle.load(chars_file)\n",
    "\n",
    "# Vocabulary sizes and maximum sequence lengths.\n",
    "with open(\"../data/encoder_decoder_stats.pkl\", \"rb\") as stats_file:\n",
    "    (num_encoder_tokens, num_decoder_tokens,\n",
    "     max_encoder_seq_length, max_decoder_seq_length) = pickle.load(stats_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of samples: 10000\n",
      "Number of unique input tokens: 44\n",
      "Number of unique output tokens: 46\n",
      "Max sequence length for inputs: 100\n",
      "Max sequence length for outputs: 102\n"
     ]
    }
   ],
   "source": [
    "# Read the tab-separated (statement, reply) pairs and collect both character vocabularies.\n",
    "# NOTE: input_token_index / target_token_index are not rebuilt in this cell; the one-hot\n",
    "# encoding loop relies on the versions that were unpickled.\n",
    "input_texts = []\n",
    "target_texts = []\n",
    "input_characters = set()\n",
    "target_characters = set()\n",
    "with open(data_path) as corpus_file:  # the context manager closes the file handle\n",
    "    lines = corpus_file.read().split('\\n')\n",
    "# `len(lines) - 1` skips the final element (presumably the empty string after a trailing newline).\n",
    "for line in lines[: min(num_samples, len(lines) - 1)]:\n",
    "    input_text, target_text = line.split('\\t')\n",
    "    # We use \"tab\" as the \"start sequence\" character\n",
    "    # for the targets, and \"\\n\" as \"end sequence\" character.\n",
    "    target_text = '\\t' + target_text + '\\n'\n",
    "    input_texts.append(input_text)\n",
    "    target_texts.append(target_text)\n",
    "    input_characters.update(input_text)    # a set silently ignores characters it already holds\n",
    "    target_characters.update(target_text)\n",
    "\n",
    "# Sorted lists give every character a stable position.\n",
    "input_characters = sorted(input_characters)\n",
    "target_characters = sorted(target_characters)\n",
    "num_encoder_tokens = len(input_characters)\n",
    "num_decoder_tokens = len(target_characters)\n",
    "max_encoder_seq_length = max(len(txt) for txt in input_texts)\n",
    "max_decoder_seq_length = max(len(txt) for txt in target_texts)\n",
    "\n",
    "print('Number of samples:', len(input_texts))\n",
    "print('Number of unique input tokens:', num_encoder_tokens)\n",
    "print('Number of unique output tokens:', num_decoder_tokens)\n",
    "print('Max sequence length for inputs:', max_encoder_seq_length)\n",
    "print('Max sequence length for outputs:', max_decoder_seq_length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# All-zero one-hot tensors with axes (sample, timestep, character).\n",
    "encoder_input_data = np.zeros(\n",
    "    shape=(len(input_texts), max_encoder_seq_length, num_encoder_tokens),\n",
    "    dtype=np.float32)\n",
    "# The decoder's input and target tensors have the same shape and dtype.\n",
    "decoder_target_data = np.zeros(\n",
    "    shape=(len(input_texts), max_decoder_seq_length, num_decoder_tokens),\n",
    "    dtype=np.float32)\n",
    "decoder_input_data = np.zeros_like(decoder_target_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-hot encode every character of each (statement, reply) pair.\n",
    "for sample, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):\n",
    "    for step, char in enumerate(input_text):\n",
    "        encoder_input_data[sample, step, input_token_index[char]] = 1.\n",
    "    # The decoder input is the full target text, start character included.\n",
    "    for step, char in enumerate(target_text):\n",
    "        decoder_input_data[sample, step, target_token_index[char]] = 1.\n",
    "    # The decoder target is the same text one timestep ahead,\n",
    "    # so it does not include the start character.\n",
    "    for step, char in enumerate(target_text[1:]):\n",
    "        decoder_target_data[sample, step, target_token_index[char]] = 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model = construct_seq2seq_model(num_encoder_tokens, num_decoder_tokens)\n",
    "from keras.layers import Input, LSTM, Dense\n",
    "batch_size = 64    # examples per gradient update\n",
    "epochs = 100       # passes over the training data\n",
    "num_neurons = 256  # size of each LSTM's hidden and cell state\n",
    "\n",
    "# Encoder: consume the input sequence and keep only its final hidden and cell states.\n",
    "encoder_inputs = Input(shape=(None, num_encoder_tokens))\n",
    "encoder = LSTM(num_neurons, return_state=True)\n",
    "_, state_h, state_c = encoder(encoder_inputs)\n",
    "\n",
    "encoder_states = [state_h, state_c]\n",
    "\n",
    "# Decoder: starts from the encoder's states, so it needs the same number of units.\n",
    "# (`latent_dim` is never defined in this notebook; num_neurons is the intended size.)\n",
    "decoder_inputs = Input(shape=(None, num_decoder_tokens))\n",
    "decoder_lstm = LSTM(num_neurons, return_sequences=True, return_state=True)\n",
    "decoder_outputs, _, _ = decoder_lstm(decoder_inputs,\n",
    "                                     initial_state=encoder_states)\n",
    "# Softmax over the output vocabulary at every timestep.\n",
    "decoder_dense = Dense(num_decoder_tokens, activation='softmax')\n",
    "decoder_outputs = decoder_dense(decoder_outputs)\n",
    "\n",
    "model = Model([encoder_inputs, decoder_inputs], decoder_outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 8000 samples, validate on 2000 samples\n",
      "Epoch 1/100\n",
      "8000/8000 [==============================] - 215s 27ms/step - loss: 1.0202 - acc: 0.0647 - val_loss: 0.8652 - val_acc: 0.0855\n",
      "Epoch 2/100\n",
      "8000/8000 [==============================] - 209s 26ms/step - loss: 0.8518 - acc: 0.1027 - val_loss: 0.7556 - val_acc: 0.1125\n",
      "Epoch 3/100\n",
      "8000/8000 [==============================] - 209s 26ms/step - loss: 0.7695 - acc: 0.1197 - val_loss: 0.7000 - val_acc: 0.1200\n",
      "Epoch 4/100\n",
      "8000/8000 [==============================] - 228s 28ms/step - loss: 0.7238 - acc: 0.1303 - val_loss: 0.6640 - val_acc: 0.1306\n",
      "Epoch 5/100\n",
      "8000/8000 [==============================] - 227s 28ms/step - loss: 0.6899 - acc: 0.1391 - val_loss: 0.6448 - val_acc: 0.1364\n",
      "Epoch 6/100\n",
      "8000/8000 [==============================] - 231s 29ms/step - loss: 0.6647 - acc: 0.1464 - val_loss: 0.6160 - val_acc: 0.1434\n",
      "Epoch 7/100\n",
      "8000/8000 [==============================] - 232s 29ms/step - loss: 0.6423 - acc: 0.1515 - val_loss: 0.6016 - val_acc: 0.1468\n",
      "Epoch 8/100\n",
      "8000/8000 [==============================] - 417s 52ms/step - loss: 0.6233 - acc: 0.1560 - val_loss: 0.5868 - val_acc: 0.1506\n",
      "Epoch 9/100\n",
      "8000/8000 [==============================] - 422s 53ms/step - loss: 0.6069 - acc: 0.1608 - val_loss: 0.5747 - val_acc: 0.1527\n",
      "Epoch 10/100\n",
      "8000/8000 [==============================] - 403s 50ms/step - loss: 0.5927 - acc: 0.1645 - val_loss: 0.5641 - val_acc: 0.1568\n",
      "Epoch 11/100\n",
      "8000/8000 [==============================] - 404s 50ms/step - loss: 0.5798 - acc: 0.1681 - val_loss: 0.5554 - val_acc: 0.1579\n",
      "Epoch 12/100\n",
      "8000/8000 [==============================] - 234s 29ms/step - loss: 0.5688 - acc: 0.1714 - val_loss: 0.5514 - val_acc: 0.1581\n",
      "Epoch 13/100\n",
      "8000/8000 [==============================] - 199s 25ms/step - loss: 0.5583 - acc: 0.1739 - val_loss: 0.5411 - val_acc: 0.1628\n",
      "Epoch 14/100\n",
      "8000/8000 [==============================] - 213s 27ms/step - loss: 0.5499 - acc: 0.1764 - val_loss: 0.5385 - val_acc: 0.1630\n",
      "Epoch 15/100\n",
      "8000/8000 [==============================] - 196s 24ms/step - loss: 0.5403 - acc: 0.1792 - val_loss: 0.5307 - val_acc: 0.1652\n",
      "Epoch 16/100\n",
      "8000/8000 [==============================] - 196s 25ms/step - loss: 0.5320 - acc: 0.1811 - val_loss: 0.5274 - val_acc: 0.1659\n",
      "Epoch 17/100\n",
      "8000/8000 [==============================] - 203s 25ms/step - loss: 0.5244 - acc: 0.1831 - val_loss: 0.5251 - val_acc: 0.1661\n",
      "Epoch 18/100\n",
      "8000/8000 [==============================] - 250s 31ms/step - loss: 0.5170 - acc: 0.1854 - val_loss: 0.5220 - val_acc: 0.1678\n",
      "Epoch 19/100\n",
      "8000/8000 [==============================] - 203s 25ms/step - loss: 0.5102 - acc: 0.1874 - val_loss: 0.5208 - val_acc: 0.1681\n",
      "Epoch 20/100\n",
      "8000/8000 [==============================] - 198s 25ms/step - loss: 0.5036 - acc: 0.1891 - val_loss: 0.5183 - val_acc: 0.1684\n",
      "Epoch 21/100\n",
      "8000/8000 [==============================] - 209s 26ms/step - loss: 0.5064 - acc: 0.1893 - val_loss: 0.5160 - val_acc: 0.1698\n",
      "Epoch 22/100\n",
      "8000/8000 [==============================] - 224s 28ms/step - loss: 0.4918 - acc: 0.1924 - val_loss: 0.5148 - val_acc: 0.1698\n",
      "Epoch 23/100\n",
      "8000/8000 [==============================] - 235s 29ms/step - loss: 0.4858 - acc: 0.1942 - val_loss: 0.5147 - val_acc: 0.1705\n",
      "Epoch 24/100\n",
      "8000/8000 [==============================] - 209s 26ms/step - loss: 0.4800 - acc: 0.1959 - val_loss: 0.5141 - val_acc: 0.1702\n",
      "Epoch 25/100\n",
      "8000/8000 [==============================] - 232s 29ms/step - loss: 0.4746 - acc: 0.1974 - val_loss: 0.5151 - val_acc: 0.1695\n",
      "Epoch 26/100\n",
      "8000/8000 [==============================] - 249s 31ms/step - loss: 0.4690 - acc: 0.1990 - val_loss: 0.5163 - val_acc: 0.1703\n",
      "Epoch 27/100\n",
      "8000/8000 [==============================] - 212s 26ms/step - loss: 0.4636 - acc: 0.2006 - val_loss: 0.5158 - val_acc: 0.1701\n",
      "Epoch 28/100\n",
      "8000/8000 [==============================] - 235s 29ms/step - loss: 0.4676 - acc: 0.2000 - val_loss: 0.5173 - val_acc: 0.1701\n",
      "Epoch 29/100\n",
      "8000/8000 [==============================] - 254s 32ms/step - loss: 0.4545 - acc: 0.2032 - val_loss: 0.5191 - val_acc: 0.1699\n",
      "Epoch 30/100\n",
      "8000/8000 [==============================] - 240s 30ms/step - loss: 0.4512 - acc: 0.2045 - val_loss: 0.5185 - val_acc: 0.1704\n",
      "Epoch 31/100\n",
      "8000/8000 [==============================] - 234s 29ms/step - loss: 0.4463 - acc: 0.2060 - val_loss: 0.5198 - val_acc: 0.1702\n",
      "Epoch 32/100\n",
      "8000/8000 [==============================] - 233s 29ms/step - loss: 0.4415 - acc: 0.2072 - val_loss: 0.5234 - val_acc: 0.1695\n",
      "Epoch 33/100\n",
      "8000/8000 [==============================] - 239s 30ms/step - loss: 0.4356 - acc: 0.2089 - val_loss: 0.5246 - val_acc: 0.1690\n",
      "Epoch 34/100\n",
      "8000/8000 [==============================] - 234s 29ms/step - loss: 0.4300 - acc: 0.2107 - val_loss: 0.5280 - val_acc: 0.1691\n",
      "Epoch 35/100\n",
      "8000/8000 [==============================] - 256s 32ms/step - loss: 0.4249 - acc: 0.2123 - val_loss: 0.5302 - val_acc: 0.1678\n",
      "Epoch 36/100\n",
      "8000/8000 [==============================] - 211s 26ms/step - loss: 0.4200 - acc: 0.2136 - val_loss: 0.5310 - val_acc: 0.1683\n",
      "Epoch 37/100\n",
      "8000/8000 [==============================] - 208s 26ms/step - loss: 0.4153 - acc: 0.2151 - val_loss: 0.5342 - val_acc: 0.1689\n",
      "Epoch 38/100\n",
      "8000/8000 [==============================] - 208s 26ms/step - loss: 0.4104 - acc: 0.2167 - val_loss: 0.5369 - val_acc: 0.1683\n",
      "Epoch 39/100\n",
      "8000/8000 [==============================] - 209s 26ms/step - loss: 0.4062 - acc: 0.2183 - val_loss: 0.5404 - val_acc: 0.1673\n",
      "Epoch 40/100\n",
      "8000/8000 [==============================] - 208s 26ms/step - loss: 0.4011 - acc: 0.2197 - val_loss: 0.5435 - val_acc: 0.1665\n",
      "Epoch 41/100\n",
      "8000/8000 [==============================] - 211s 26ms/step - loss: 0.3967 - acc: 0.2208 - val_loss: 0.5465 - val_acc: 0.1669\n",
      "Epoch 42/100\n",
      "8000/8000 [==============================] - 210s 26ms/step - loss: 0.3925 - acc: 0.2224 - val_loss: 0.5504 - val_acc: 0.1664\n",
      "Epoch 43/100\n",
      "8000/8000 [==============================] - 211s 26ms/step - loss: 0.3881 - acc: 0.2236 - val_loss: 0.5536 - val_acc: 0.1662\n",
      "Epoch 44/100\n",
      "8000/8000 [==============================] - 211s 26ms/step - loss: 0.3838 - acc: 0.2250 - val_loss: 0.5579 - val_acc: 0.1653\n",
      "Epoch 45/100\n",
      "8000/8000 [==============================] - 210s 26ms/step - loss: 0.3796 - acc: 0.2262 - val_loss: 0.5600 - val_acc: 0.1657\n",
      "Epoch 46/100\n",
      "8000/8000 [==============================] - 211s 26ms/step - loss: 0.3756 - acc: 0.2276 - val_loss: 0.5645 - val_acc: 0.1649\n",
      "Epoch 47/100\n",
      "8000/8000 [==============================] - 211s 26ms/step - loss: 0.3720 - acc: 0.2287 - val_loss: 0.5678 - val_acc: 0.1649\n",
      "Epoch 48/100\n",
      "8000/8000 [==============================] - 209s 26ms/step - loss: 0.3677 - acc: 0.2302 - val_loss: 0.5715 - val_acc: 0.1650\n",
      "Epoch 49/100\n",
      "8000/8000 [==============================] - 212s 27ms/step - loss: 0.3639 - acc: 0.2316 - val_loss: 0.5756 - val_acc: 0.1641\n",
      "Epoch 50/100\n",
      "8000/8000 [==============================] - 211s 26ms/step - loss: 0.3603 - acc: 0.2325 - val_loss: 0.5816 - val_acc: 0.1632\n",
      "Epoch 51/100\n",
      "8000/8000 [==============================] - 210s 26ms/step - loss: 0.3567 - acc: 0.2332 - val_loss: 0.5858 - val_acc: 0.1624\n",
      "Epoch 52/100\n",
      "8000/8000 [==============================] - 214s 27ms/step - loss: 0.3903 - acc: 0.2270 - val_loss: 0.5843 - val_acc: 0.1634\n",
      "Epoch 53/100\n",
      "8000/8000 [==============================] - 208s 26ms/step - loss: 0.3545 - acc: 0.2341 - val_loss: 0.5901 - val_acc: 0.1623\n",
      "Epoch 54/100\n",
      "8000/8000 [==============================] - 212s 26ms/step - loss: 0.3482 - acc: 0.2361 - val_loss: 0.5913 - val_acc: 0.1622\n",
      "Epoch 55/100\n",
      "8000/8000 [==============================] - 210s 26ms/step - loss: 0.3748 - acc: 0.2295 - val_loss: 0.6049 - val_acc: 0.1599\n",
      "Epoch 56/100\n",
      "8000/8000 [==============================] - 243s 30ms/step - loss: 0.3662 - acc: 0.2316 - val_loss: 0.5963 - val_acc: 0.1619\n",
      "Epoch 57/100\n",
      "8000/8000 [==============================] - 224s 28ms/step - loss: 0.3423 - acc: 0.2380 - val_loss: 0.5986 - val_acc: 0.1620\n",
      "Epoch 58/100\n",
      "8000/8000 [==============================] - 210s 26ms/step - loss: 0.3385 - acc: 0.2391 - val_loss: 0.6029 - val_acc: 0.1621\n",
      "Epoch 59/100\n",
      "8000/8000 [==============================] - 210s 26ms/step - loss: 0.3355 - acc: 0.2402 - val_loss: 0.6075 - val_acc: 0.1617\n",
      "Epoch 60/100\n",
      "8000/8000 [==============================] - 211s 26ms/step - loss: 0.3323 - acc: 0.2410 - val_loss: 0.6124 - val_acc: 0.1608\n",
      "Epoch 61/100\n",
      "8000/8000 [==============================] - 210s 26ms/step - loss: 0.3296 - acc: 0.2418 - val_loss: 0.6131 - val_acc: 0.1616\n",
      "Epoch 62/100\n",
      "8000/8000 [==============================] - 210s 26ms/step - loss: 0.3268 - acc: 0.2427 - val_loss: 0.6171 - val_acc: 0.1612\n",
      "Epoch 63/100\n",
      "8000/8000 [==============================] - 211s 26ms/step - loss: 0.3243 - acc: 0.2435 - val_loss: 0.6229 - val_acc: 0.1597\n",
      "Epoch 64/100\n",
      "8000/8000 [==============================] - 210s 26ms/step - loss: 0.3222 - acc: 0.2442 - val_loss: 0.6245 - val_acc: 0.1606\n",
      "Epoch 65/100\n",
      "8000/8000 [==============================] - 211s 26ms/step - loss: 0.3195 - acc: 0.2448 - val_loss: 0.6269 - val_acc: 0.1601\n",
      "Epoch 66/100\n",
      "8000/8000 [==============================] - 214s 27ms/step - loss: 0.3173 - acc: 0.2457 - val_loss: 0.6319 - val_acc: 0.1601\n",
      "Epoch 67/100\n",
      "8000/8000 [==============================] - 212s 26ms/step - loss: 0.3152 - acc: 0.2462 - val_loss: 0.6350 - val_acc: 0.1595\n",
      "Epoch 68/100\n",
      "8000/8000 [==============================] - 210s 26ms/step - loss: 0.3130 - acc: 0.2469 - val_loss: 0.6387 - val_acc: 0.1601\n",
      "Epoch 69/100\n",
      "8000/8000 [==============================] - 214s 27ms/step - loss: 0.3107 - acc: 0.2478 - val_loss: 0.6444 - val_acc: 0.1597\n",
      "Epoch 70/100\n",
      "8000/8000 [==============================] - 210s 26ms/step - loss: 0.3085 - acc: 0.2484 - val_loss: 0.6452 - val_acc: 0.1599\n",
      "Epoch 71/100\n",
      "8000/8000 [==============================] - 212s 26ms/step - loss: 0.3076 - acc: 0.2485 - val_loss: 0.6487 - val_acc: 0.1596\n",
      "Epoch 72/100\n",
      "8000/8000 [==============================] - 228s 28ms/step - loss: 0.3049 - acc: 0.2493 - val_loss: 0.6526 - val_acc: 0.1592\n",
      "Epoch 73/100\n",
      "8000/8000 [==============================] - 211s 26ms/step - loss: 0.3037 - acc: 0.2499 - val_loss: 0.6538 - val_acc: 0.1593\n",
      "Epoch 74/100\n",
      "8000/8000 [==============================] - 212s 26ms/step - loss: 0.3016 - acc: 0.2503 - val_loss: 0.6566 - val_acc: 0.1591\n",
      "Epoch 75/100\n",
      "8000/8000 [==============================] - 209s 26ms/step - loss: 0.3000 - acc: 0.2506 - val_loss: 0.6581 - val_acc: 0.1585\n",
      "Epoch 76/100\n",
      "8000/8000 [==============================] - 211s 26ms/step - loss: 0.2991 - acc: 0.2513 - val_loss: 0.6610 - val_acc: 0.1596\n",
      "Epoch 77/100\n",
      "8000/8000 [==============================] - 213s 27ms/step - loss: 0.2966 - acc: 0.2519 - val_loss: 0.6668 - val_acc: 0.1584\n",
      "Epoch 78/100\n",
      "8000/8000 [==============================] - 211s 26ms/step - loss: 0.2948 - acc: 0.2524 - val_loss: 0.6680 - val_acc: 0.1586\n",
      "Epoch 79/100\n",
      "8000/8000 [==============================] - 209s 26ms/step - loss: 0.2934 - acc: 0.2529 - val_loss: 0.6721 - val_acc: 0.1579\n",
      "Epoch 80/100\n",
      "8000/8000 [==============================] - 211s 26ms/step - loss: 0.2921 - acc: 0.2531 - val_loss: 0.6754 - val_acc: 0.1584\n",
      "Epoch 81/100\n",
      "8000/8000 [==============================] - 211s 26ms/step - loss: 0.2907 - acc: 0.2535 - val_loss: 0.6775 - val_acc: 0.1585\n",
      "Epoch 82/100\n",
      "8000/8000 [==============================] - 211s 26ms/step - loss: 0.2892 - acc: 0.2538 - val_loss: 0.6828 - val_acc: 0.1580\n",
      "Epoch 83/100\n",
      "8000/8000 [==============================] - 212s 26ms/step - loss: 0.2880 - acc: 0.2543 - val_loss: 0.6824 - val_acc: 0.1577\n",
      "Epoch 84/100\n",
      "8000/8000 [==============================] - 210s 26ms/step - loss: 0.2866 - acc: 0.2548 - val_loss: 0.6864 - val_acc: 0.1577\n",
      "Epoch 85/100\n",
      "8000/8000 [==============================] - 211s 26ms/step - loss: 0.2848 - acc: 0.2557 - val_loss: 0.6883 - val_acc: 0.1578\n",
      "Epoch 86/100\n",
      "8000/8000 [==============================] - 233s 29ms/step - loss: 0.2835 - acc: 0.2557 - val_loss: 0.6874 - val_acc: 0.1573\n",
      "Epoch 87/100\n",
      "8000/8000 [==============================] - 210s 26ms/step - loss: 0.2826 - acc: 0.2561 - val_loss: 0.6908 - val_acc: 0.1581\n",
      "Epoch 88/100\n",
      "8000/8000 [==============================] - 212s 26ms/step - loss: 0.2808 - acc: 0.2566 - val_loss: 0.6974 - val_acc: 0.1574\n",
      "Epoch 89/100\n",
      "8000/8000 [==============================] - 211s 26ms/step - loss: 0.2799 - acc: 0.2567 - val_loss: 0.6991 - val_acc: 0.1575\n",
      "Epoch 90/100\n",
      "8000/8000 [==============================] - 211s 26ms/step - loss: 0.2789 - acc: 0.2570 - val_loss: 0.7016 - val_acc: 0.1574\n",
      "Epoch 91/100\n",
      "8000/8000 [==============================] - 210s 26ms/step - loss: 0.2781 - acc: 0.2573 - val_loss: 0.7036 - val_acc: 0.1569\n",
      "Epoch 92/100\n",
      "8000/8000 [==============================] - 210s 26ms/step - loss: 0.2760 - acc: 0.2581 - val_loss: 0.7062 - val_acc: 0.1571\n",
      "Epoch 93/100\n",
      "8000/8000 [==============================] - 211s 26ms/step - loss: 0.2752 - acc: 0.2583 - val_loss: 0.7118 - val_acc: 0.1568\n",
      "Epoch 94/100\n",
      "8000/8000 [==============================] - 210s 26ms/step - loss: 0.2744 - acc: 0.2585 - val_loss: 0.7104 - val_acc: 0.1575\n",
      "Epoch 95/100\n",
      "8000/8000 [==============================] - 212s 26ms/step - loss: 0.2732 - acc: 0.2588 - val_loss: 0.7117 - val_acc: 0.1569\n",
      "Epoch 96/100\n",
      "8000/8000 [==============================] - 209s 26ms/step - loss: 0.2718 - acc: 0.2593 - val_loss: 0.7139 - val_acc: 0.1576\n",
      "Epoch 97/100\n",
      "8000/8000 [==============================] - 210s 26ms/step - loss: 0.2707 - acc: 0.2593 - val_loss: 0.7166 - val_acc: 0.1571\n",
      "Epoch 98/100\n",
      "8000/8000 [==============================] - 213s 27ms/step - loss: 0.2699 - acc: 0.2597 - val_loss: 0.7189 - val_acc: 0.1569\n",
      "Epoch 99/100\n",
      "8000/8000 [==============================] - 211s 26ms/step - loss: 0.2686 - acc: 0.2601 - val_loss: 0.7236 - val_acc: 0.1571\n",
      "Epoch 100/100\n",
      "8000/8000 [==============================] - 211s 26ms/step - loss: 0.2686 - acc: 0.2600 - val_loss: 0.7234 - val_acc: 0.1561\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0xb61cdd0b8>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Compile, then fit the encoder/decoder pair against the shifted target sequences.\n",
    "model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['acc'])\n",
    "model.fit(x=[encoder_input_data, decoder_input_data],\n",
    "          y=decoder_target_data,\n",
    "          epochs=epochs,\n",
    "          batch_size=batch_size,\n",
    "          validation_split=0.2)  # hold out 20% of the samples for validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/hobsonlane/anaconda3/envs/nlpiaenv/lib/python3.6/site-packages/keras/engine/network.py:877: UserWarning: Layer lstm_2 was passed non-serializable keyword arguments: {'initial_state': [<tf.Tensor 'lstm_1/while/Exit_2:0' shape=(?, 256) dtype=float32>, <tf.Tensor 'lstm_1/while/Exit_3:0' shape=(?, 256) dtype=float32>]}. They will not be included in the serialized model (and thus will be missing at deserialization time).\n",
      "  '. They will not be included '\n"
     ]
    }
   ],
   "source": [
    "# Keep the base path free of an extension: the loading cells read\n",
    "# '<model_path>_model.h5' and '<model_path>_weights.h5', so the saved names must match.\n",
    "model_path = os.path.join(DATA_PATH, 'ch10_train_seq2seq_keras')\n",
    "model.save(model_path + '_model.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store the bare layer weights in their own HDF5 file.\n",
    "model.save_weights('{}_weights.h5'.format(model_path))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Inference/Activation without training\n",
    "Everything below can be run without rerunning training  \n",
    "TODO: put these cells in a separate notebook named ch10_inference_...  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.models import load_model\n",
    "\n",
    "# Restore the seq2seq network from '<model_path>_model.h5'.\n",
    "model_path = os.path.join(DATA_PATH, 'ch10_train_seq2seq_keras')\n",
    "model = load_model('{}_model.h5'.format(model_path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the separately stored weights into the network.\n",
    "model.load_weights('{}_weights.h5'.format(model_path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/hobsonlane/anaconda3/envs/nlpiaenv/lib/python3.6/site-packages/ipykernel_launcher.py:11: UserWarning: Update your `Model` call to the Keras 2 API: `Model(inputs=[<tf.Tenso..., outputs=[<tf.Tenso...)`\n",
      "  # This is added back by InteractiveShellApp.init_path()\n"
     ]
    }
   ],
   "source": [
    "from keras.layers import Input\n",
    "\n",
    "# Take every piece from `model` itself so both samplers share its weights.\n",
    "# After load_model(), `model` is a new object; layer variables left over from the\n",
    "# construction cell belong to a different graph and do not carry the loaded weights.\n",
    "# Expected layer order (as in the Keras lstm_seq2seq_restore example):\n",
    "# [encoder input, decoder input, encoder LSTM, decoder LSTM, dense].\n",
    "encoder_inputs, decoder_inputs = model.input\n",
    "_, state_h, state_c = model.layers[2].output\n",
    "encoder_states = [state_h, state_c]\n",
    "decoder_lstm = model.layers[3]\n",
    "decoder_dense = model.layers[4]\n",
    "num_neurons = decoder_lstm.units\n",
    "\n",
    "# Encoder sampler: input sequence -> thought vector (final hidden and cell state).\n",
    "encoder_model = Model(inputs=encoder_inputs, outputs=encoder_states)\n",
    "\n",
    "# Decoder sampler: the LSTM state is passed in and returned explicitly so that\n",
    "# generation can proceed one character at a time.\n",
    "thought_input = [\n",
    "    Input(shape=(num_neurons,)), Input(shape=(num_neurons,))]\n",
    "decoder_outputs, state_h, state_c = decoder_lstm(\n",
    "    decoder_inputs, initial_state=thought_input)\n",
    "decoder_states = [state_h, state_c]\n",
    "decoder_outputs = decoder_dense(decoder_outputs)\n",
    "\n",
    "# `outputs=` is the Keras 2 keyword; the old `output=` spelling only triggers a warning.\n",
    "decoder_model = Model(\n",
    "    inputs=[decoder_inputs] + thought_input,\n",
    "    outputs=[decoder_outputs] + decoder_states)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode_sequence(input_seq):\n",
    "    \"\"\"Greedily generate a reply, one character at a time, for one encoded input.\n",
    "\n",
    "    Args:\n",
    "        input_seq: array handed to ``encoder_model.predict``; assumed to be a one-hot\n",
    "            batch of shape (1, timesteps, num_encoder_tokens) -- TODO confirm with callers.\n",
    "\n",
    "    Returns:\n",
    "        The generated string. It ends with the stop character when the decoder emits one;\n",
    "        otherwise generation is cut off once it grows past max_decoder_seq_length.\n",
    "    \"\"\"\n",
    "    # Same markers the training targets were wrapped in.\n",
    "    start_token, stop_token = '\\t', '\\n'\n",
    "    # Invert the char -> index table so predicted indices can be turned back into characters.\n",
    "    reverse_target_char_index = {\n",
    "        index: char for char, index in target_token_index.items()}\n",
    "\n",
    "    # The encoder's final [hidden, cell] states seed the decoder.\n",
    "    thought = encoder_model.predict(input_seq)\n",
    "\n",
    "    # The first decoder input is the start token, matching how targets begin in training.\n",
    "    target_seq = np.zeros((1, 1, num_decoder_tokens))\n",
    "    target_seq[0, 0, target_token_index[start_token]] = 1.\n",
    "    generated_sequence = ''\n",
    "\n",
    "    while True:\n",
    "        output_tokens, h, c = decoder_model.predict([target_seq] + thought)\n",
    "\n",
    "        # Greedy choice: take the most probable next character.\n",
    "        generated_token_idx = np.argmax(output_tokens[0, -1, :])\n",
    "        generated_char = reverse_target_char_index[generated_token_idx]\n",
    "        generated_sequence += generated_char\n",
    "        if (generated_char == stop_token or\n",
    "                len(generated_sequence) > max_decoder_seq_length):\n",
    "            break\n",
    "\n",
    "        # Feed the chosen character and the updated state back in for the next step.\n",
    "        target_seq = np.zeros((1, 1, num_decoder_tokens))\n",
    "        target_seq[0, 0, generated_token_idx] = 1.\n",
    "        thought = [h, c]\n",
    "\n",
    "    return generated_sequence"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
